/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
#include "tensorflow/lite/micro/test_helpers.h"
#include "tensorflow/lite/micro/testing/micro_test.h"
namespace tflite {
namespace testing {
namespace {

// Naming convention for the test data below:
// <tensor name>_<input size>x<batch size>x<batch count>

// Input sequence for the 2x2x10 test case: 10 inputs, each with shape {2, 2}
// (40 values in total). Each row below holds the 4 values of one input.
const float input_data_2x2x10[] = {
    0.12609188,  -0.46347019, 0.35867718,  0.36897406,

    0.14278367,  -1.64410412, -0.57290924, 0.12729003,

    0.49837467,  0.19278903,  0.17660543,  0.52949083,

    -0.11186574, 0.13164264,  -0.72674477, -0.5683046,

    -0.68892461, 0.37783599,  -0.63690937, 0.44483393,

    -0.81299269, -0.86831826, -0.95760226, 1.82078898,

    -1.45006323, -0.82251364, -1.65087092, -1.89238167,

    0.03966608,  -0.24936394, 2.06740379,  -1.51439476,

    0.11771342,  -0.23761693, 0.31088525,  -1.55601168,

    -0.89477462, 1.67204106,  -0.6230064,  0.29819036,
};

// Feature filter weights of shape {8, 2} for the 2x2x10 test case (16 values,
// stored as a flat array).
const float feature_weights_data_2x2x10[] = {
    -0.31930989, 0.0079667,  0.39296314,  0.37613347,  0.12416199,  0.15785322,
    0.27901134,  0.3905206,  0.21931258,  -0.36137494, -0.10640851, 0.31053296,
    -0.36118156, -0.0976817, -0.36916667, 0.22197971};

// Time filter weights of shape {8, 10} for the 2x2x10 test case (80 values).
// Each blank-line-separated group below holds one row of 10 values.
const float time_weights_data_2x2x10[] = {
    -0.31930989, 0.37613347,  0.27901134,  -0.36137494, -0.36118156,
    0.22197971,  0.27557442,  -0.06634006, 0.0079667,   0.12416199,

    0.3905206,   -0.10640851, -0.0976817,  0.15294972,  0.39635518,
    -0.02702999, 0.39296314,  0.15785322,  0.21931258,  0.31053296,

    -0.36916667, 0.38031587,  -0.21580373, 0.27072677,  0.23622236,
    0.34936687,  0.18174365,  0.35907319,  -0.17493086, 0.324846,

    -0.10781813, 0.27201805,  0.14324132,  -0.23681851, -0.27115166,
    -0.01580888, -0.14943552, 0.15465137,  0.09784451,  -0.0337657,

    -0.14884081, 0.19931212,  -0.36002168, 0.34663299,  -0.11405486,
    0.12672701,  0.39463779,  -0.07886535, -0.06384811, 0.08249187,

    -0.26816407, -0.19905911, 0.29211238,  0.31264046,  -0.28664589,
    0.05698794,  0.11613581,  0.14078894,  0.02187902,  -0.21781836,

    -0.15567942, 0.08693647,  -0.38256618, 0.36580828,  -0.22922277,
    -0.0226903,  0.12878349,  -0.28122205, -0.10850525, -0.11955214,

    0.27179423,  -0.04710215, 0.31069002,  0.22672787,  0.09580326,
    0.08682203,  0.1258215,   0.1851041,   0.29228821,  0.12366763};

// Initial activation state for the 2x2x10 test case, shape {2, 80}: 160
// floats, all zero (the empty initializer zero-fills the whole array). The
// data is const, so it must be copied into a mutable activation state tensor.
const float initial_activation_state_data_2x2x10[2 * 80] = {};

// Bias with shape {8} for the 2x2x10 test case: every entry is zero (the empty
// initializer zero-fills the whole array).
const float bias_data_2x2x10[8] = {};

// Expected outputs for the 2x2x10 test case: 10 outputs, each of shape {2, 4}
// (80 values in total). Each blank-line-separated group below holds the 8
// values of one output.
const float golden_output_2x2x10[] = {
    -0.044205, -0.013757, 0.050369,  -0.018447,
    0.073010,  0.025142,  -0.021154, 0.013551,

    -0.209613, -0.062421, 0.150209,  -0.108334,
    0.028256,  -0.006950, -0.030885, 0.009603,

    -0.076800, -0.037075, -0.087198, -0.155183,
    0.091069,  0.098446,  -0.016083, 0.106475,

    -0.082123, -0.162238, -0.084434, -0.141074,
    -0.029340, -0.090685, 0.053302,  -0.030604,

    -0.201440, 0.088424,  0.139877,  0.012416,
    -0.113212, 0.103893,  -0.100842, 0.122780,

    -0.166632, -0.116705, 0.175298,  -0.047163,
    0.313077,  -0.166485, -0.285860, 0.129069,

    -0.625911, 0.046134,  0.138081,  -0.129581,
    -0.521455, -0.061579, 0.230289,  0.114963,

    -0.216693, -0.161643, -0.179177, -0.052599,
    -0.213239, 0.029502,  0.260858,  0.275045,

    -0.213689, -0.323608, -0.285635, -0.317687,
    -0.324092, -0.317972, -0.208450, -0.462504,

    -0.255126, -0.218576, -0.041528, 0.179421,
    -0.440583, 0.072127,  -0.284136, 0.241570};

// Simulated real-world inputs, weights and expected outputs.

// Input of shape {1, 16} for the 16x1x1 test case: a single input holding 16
// values.
const float input_data_16x1x1[] = {
    -0.488494, 2.023762,  -2.233117, -0.488494, 3.559030, 9.490748,
    -3.210106, -1.953977, -0.279140, 0.907204,  1.674838, 0.000000,
    -0.279140, -0.628064, -0.069785, -0.628064,
};

// Feature filter of shape {64, 16}.
const float feature_weights_data_16x1x1[] = {
    0.173588,  0.173588,  -0.024798, 0.193426,  -0.099193, 0.044637,  0.183507,
    0.183507,  0.044637,  0.198386,  -0.069435, 0.084314,  0.312458,  0.024798,
    0.173588,  -0.049596, -0.352135, -0.550521, -0.009919, -0.099193, -0.074395,
    -0.128951, 0.193426,  0.357095,  -0.317418, -0.119032, -0.218225, -0.004960,
    -0.386853, -0.133911, 0.252942,  -0.019839, -0.024798, -0.054556, -0.069435,
    -0.128951, 0.029758,  -0.099193, -0.312458, -0.029758, 0.064475,  0.183507,
    0.114072,  -0.178547, -0.247982, -0.119032, 0.243023,  -0.119032, -0.034718,
    -0.178547, 0.019839,  0.128951,  -0.223184, -0.009919, -0.213265, 0.168628,
    -0.143830, -0.322377, -0.218225, -0.193426, -0.252942, -0.049596, 0.064475,
    -0.267821, -0.580279, -0.099193, 0.213265,  0.119032,  -0.119032, -0.178547,
    0.610037,  0.109112,  0.049596,  -0.014879, -0.049596, -0.193426, 0.039677,
    -0.148789, -0.114072, -0.158709, -0.158709, 0.094233,  0.099193,  -0.114072,
    0.104153,  -0.123991, 0.198386,  -0.173588, 0.089274,  -0.247982, -0.054556,
    0.123991,  0.183507,  0.114072,  0.188467,  0.302539,  0.044637,  0.039677,
    -0.099193, 0.168628,  -0.024798, -0.054556, -0.109112, 0.014879,  -0.009919,
    0.069435,  -0.396772, -0.287660, -0.079354, -0.104153, 0.054556,  0.089274,
    -0.099193, 0.114072,  0.034718,  0.119032,  0.282700,  -0.119032, -0.505884,
    -0.233104, -0.114072, -0.257902, -0.233104, -0.178547, 0.153749,  0.128951,
    0.143830,  -0.188467, -0.183507, 0.104153,  -0.024798, 0.193426,  -0.287660,
    0.168628,  -0.009919, 0.119032,  -0.024798, -0.099193, -0.203346, 0.099193,
    0.084314,  -0.168628, 0.123991,  -0.148789, 0.114072,  -0.029758, 0.228144,
    -0.238063, 0.089274,  -0.064475, 0.307498,  -0.188467, -0.004960, -0.252942,
    -0.173588, -0.158709, -0.044637, -0.009919, 0.312458,  -0.262861, 0.059516,
    0.158709,  0.069435,  -0.282700, 0.074395,  -0.322377, -0.183507, -0.123991,
    -0.233104, 0.009919,  0.252942,  -0.243023, 0.555481,  -0.099193, -0.119032,
    -0.441409, 0.148789,  0.084314,  -0.168628, -0.183507, 0.188467,  0.024798,
    -0.302539, 0.223184,  0.143830,  -0.193426, -0.054556, -0.218225, -0.297579,
    0.104153,  0.272781,  -0.034718, 0.114072,  -0.059516, 0.044637,  0.342216,
    0.421570,  0.138870,  -0.024798, -0.039677, -0.163668, -0.034718, 0.396772,
    -0.128951, -0.044637, -0.173588, 0.302539,  0.079354,  0.049596,  0.133911,
    -0.029758, -0.312458, -0.029758, 0.079354,  0.128951,  0.252942,  0.213265,
    0.014879,  0.287660,  0.178547,  0.297579,  0.352135,  0.401732,  0.024798,
    -0.277740, -0.411651, -0.069435, 0.342216,  -0.158709, -0.104153, -0.009919,
    0.223184,  0.228144,  -0.019839, 0.059516,  -0.104153, -0.510844, 0.029758,
    -0.406691, 0.089274,  0.421570,  0.163668,  -0.143830, -0.019839, -0.039677,
    0.104153,  -0.044637, -0.128951, 0.203346,  0.079354,  -0.069435, 0.094233,
    -0.138870, 0.466207,  -0.163668, 0.049596,  0.029758,  0.267821,  0.029758,
    -0.049596, 0.009919,  0.004960,  -0.099193, 0.094233,  -0.262861, 0.089274,
    -0.302539, 0.332297,  -0.307498, -0.014879, 0.168628,  -0.094233, -0.272781,
    0.034718,  -0.133911, -0.228144, 0.094233,  0.257902,  -0.228144, 0.153749,
    -0.054556, -0.252942, 0.054556,  0.218225,  -0.054556, 0.302539,  0.282700,
    0.054556,  -0.044637, -0.133911, 0.233104,  -0.049596, 0.411651,  0.044637,
    -0.297579, -0.029758, -0.114072, 0.114072,  -0.580279, 0.079354,  -0.024798,
    -0.347175, -0.128951, -0.099193, 0.238063,  -0.104153, -0.009919, 0.158709,
    -0.034718, 0.123991,  -0.163668, 0.059516,  0.342216,  0.009919,  0.064475,
    -0.307498, -0.520763, -0.238063, 0.163668,  0.362054,  0.034718,  -0.178547,
    -0.104153, -0.257902, 0.322377,  0.054556,  0.148789,  -0.178547, 0.084314,
    0.004960,  0.257902,  0.029758,  0.079354,  -0.223184, -0.193426, 0.282700,
    0.000000,  -0.019839, -0.114072, 0.491005,  -0.193426, -0.029758, -0.243023,
    0.009919,  0.089274,  -0.277740, -0.089274, 0.104153,  0.337256,  0.138870,
    -0.307498, -0.054556, 0.352135,  0.133911,  -0.044637, 0.133911,  -0.089274,
    -0.357095, -0.272781, 0.069435,  0.059516,  -0.109112, 0.148789,  -0.044637,
    -0.019839, -0.153749, 0.123991,  -0.223184, 0.322377,  0.074395,  -0.312458,
    0.024798,  -0.223184, 0.109112,  -0.138870, 0.218225,  -0.074395, -0.406691,
    0.009919,  -0.198386, -0.009919, 0.416611,  0.178547,  0.148789,  0.133911,
    -0.004960, 0.069435,  -0.054556, -0.044637, 0.297579,  0.059516,  -0.456288,
    -0.148789, -0.004960, 0.054556,  0.094233,  -0.104153, 0.198386,  -0.302539,
    0.133911,  0.411651,  0.054556,  0.525723,  -0.089274, 0.079354,  0.238063,
    0.079354,  -0.039677, 0.039677,  0.029758,  0.332297,  -0.014879, -0.367014,
    -0.143830, -0.123991, -0.064475, 0.014879,  0.173588,  -0.168628, 0.386853,
    0.009919,  0.173588,  0.163668,  0.123991,  0.163668,  0.198386,  0.203346,
    -0.401732, -0.009919, 0.272781,  -0.173588, 0.044637,  0.238063,  0.133911,
    0.049596,  0.208305,  -0.024798, 0.049596,  -0.049596, 0.034718,  -0.446368,
    0.466207,  -0.089274, -0.099193, -0.128951, -0.228144, 0.014879,  -0.252942,
    0.074395,  -0.223184, -0.168628, -0.292619, 0.178547,  0.153749,  -0.014879,
    0.054556,  0.000000,  0.193426,  0.158709,  0.178547,  -0.327337, -0.138870,
    -0.114072, 0.168628,  0.297579,  -0.109112, -0.029758, -0.029758, -0.416611,
    0.059516,  0.000000,  -0.168628, -0.322377, 0.238063,  -0.128951, -0.029758,
    0.500925,  0.292619,  0.123991,  -0.099193, 0.074395,  0.317418,  -0.148789,
    0.064475,  -0.104153, -0.044637, -0.094233, 0.188467,  -0.044637, 0.213265,
    -0.233104, -0.049596, 0.004960,  -0.198386, 0.287660,  -0.148789, -0.257902,
    0.004960,  -0.218225, -0.044637, -0.386853, -0.243023, -0.163668, 0.094233,
    0.029758,  -0.019839, -0.009919, -0.143830, -0.158709, 0.158709,  -0.243023,
    -0.039677, -0.297579, 0.069435,  0.049596,  0.302539,  0.059516,  0.074395,
    -0.019839, 0.352135,  -0.019839, -0.138870, -0.178547, -0.243023, 0.233104,
    0.252942,  -0.228144, -0.049596, 0.173588,  0.173588,  -0.074395, -0.034718,
    -0.292619, 0.362054,  0.183507,  0.243023,  -0.203346, -0.044637, 0.054556,
    0.059516,  -0.158709, -0.158709, 0.000000,  0.327337,  0.119032,  0.034718,
    -0.044637, -0.089274, 0.089274,  -0.233104, 0.000000,  -0.317418, 0.371974,
    0.213265,  0.307498,  -0.178547, -0.367014, 0.039677,  -0.059516, 0.168628,
    -0.014879, 0.143830,  0.123991,  -0.084314, -0.332297, -0.416611, 0.183507,
    0.109112,  -0.039677, 0.014879,  0.292619,  -0.213265, -0.054556, 0.004960,
    0.123991,  0.119032,  0.000000,  -0.332297, -0.312458, -0.198386, -0.213265,
    0.119032,  0.322377,  0.168628,  0.104153,  -0.262861, 0.327337,  -0.049596,
    -0.228144, -0.074395, 0.168628,  0.123991,  0.396772,  0.044637,  0.322377,
    0.193426,  0.267821,  -0.178547, 0.297579,  0.148789,  -0.218225, -0.138870,
    0.044637,  0.049596,  0.133911,  0.064475,  0.069435,  0.064475,  -0.158709,
    -0.044637, -0.173588, 0.267821,  0.327337,  0.079354,  -0.228144, 0.029758,
    0.014879,  0.198386,  -0.109112, -0.133911, 0.431490,  0.099193,  0.421570,
    0.233104,  -0.054556, 0.054556,  -0.317418, -0.133911, -0.123991, -0.287660,
    0.342216,  -0.049596, -0.153749, 0.228144,  -0.213265, 0.262861,  0.406691,
    -0.084314, -0.004960, 0.193426,  0.188467,  -0.099193, -0.223184, 0.163668,
    -0.257902, -0.153749, 0.441409,  0.099193,  0.128951,  -0.089274, -0.208305,
    -0.009919, -0.004960, -0.109112, 0.024798,  -0.119032, 0.019839,  0.391812,
    -0.024798, 0.198386,  0.327337,  -0.505884, -0.099193, 0.510844,  -0.148789,
    0.094233,  -0.153749, -0.039677, 0.352135,  0.272781,  -0.228144, -0.287660,
    -0.272781, 0.148789,  0.277740,  0.074395,  0.109112,  -0.064475, 0.044637,
    0.074395,  -0.292619, 0.153749,  -0.064475, -0.114072, 0.198386,  -0.039677,
    -0.128951, -0.004960, 0.257902,  -0.228144, -0.094233, 0.064475,  0.014879,
    0.188467,  -0.416611, 0.099193,  0.362054,  -0.208305, 0.198386,  -0.079354,
    0.009919,  0.119032,  0.332297,  0.243023,  -0.168628, 0.158709,  0.039677,
    0.143830,  0.277740,  -0.168628, 0.009919,  0.099193,  -0.004960, -0.257902,
    -0.297579, 0.208305,  -0.104153, 0.119032,  0.247982,  0.381893,  -0.223184,
    -0.367014, -0.327337, -0.168628, -0.094233, 0.208305,  -0.019839, 0.183507,
    0.084314,  0.133911,  0.109112,  -0.148789, -0.183507, -0.411651, -0.024798,
    -0.114072, -0.029758, -0.009919, 0.173588,  -0.059516, -0.049596, 0.039677,
    0.317418,  0.138870,  -0.247982, -0.084314, 0.158709,  0.054556,  -0.084314,
    -0.049596, 0.074395,  0.019839,  -0.282700, -0.119032, -0.262861, 0.163668,
    -0.069435, -0.064475, -0.059516, 0.094233,  0.123991,  -0.079354, -0.272781,
    -0.267821, 0.233104,  0.114072,  -0.218225, 0.540602,  0.089274,  0.262861,
    0.079354,  0.267821,  -0.119032, -0.109112, -0.128951, 0.128951,  -0.044637,
    -0.272781, 0.277740,  0.297579,  -0.054556, -0.084314, -0.049596, 0.123991,
    0.059516,  0.238063,  -0.168628, -0.009919, 0.163668,  -0.307498, 0.109112,
    -0.064475, 0.218225,  -0.168628, -0.004960, -0.168628, 0.119032,  0.094233,
    -0.183507, -0.089274, -0.292619, -0.094233, 0.064475,  -0.183507, -0.168628,
    0.089274,  0.074395,  -0.367014, -0.024798, -0.069435, 0.119032,  -0.302539,
    -0.376933, -0.123991, -0.009919, -0.069435, -0.208305, -0.119032, 0.014879,
    -0.183507, -0.238063, 0.163668,  -0.332297, -0.148789, -0.391812, -0.024798,
    -0.133911, -0.059516, -0.123991, 0.123991,  -0.292619, -0.044637, 0.059516,
    -0.069435, 0.049596,  -0.069435, 0.034718,  0.158709,  -0.347175, -0.044637,
    0.352135,  -0.347175, -0.282700, -0.054556, 0.307498,  0.029758,  0.357095,
    -0.148789, 0.208305,  -0.317418, 0.009919,  0.004960,  -0.243023, 0.049596,
    -0.099193, 0.213265,  -0.342216, 0.158709,  0.123991,  -0.332297, 0.386853,
    -0.262861, -0.208305, 0.123991,  -0.044637, 0.148789,  0.084314,  -0.297579,
    -0.307498, -0.163668, 0.337256,  -0.014879, 0.074395,  0.178547,  -0.004960,
    -0.257902, -0.019839, -0.228144, -0.034718, -0.277740, -0.158709, -0.119032,
    -0.153749, 0.629876,  0.277740,  0.178547,  -0.267821, -0.004960, 0.247982,
    0.084314,  -0.094233, 0.000000,  -0.039677, 0.332297,  0.178547,  0.009919,
    -0.213265, -0.208305, -0.044637, 0.019839,  0.218225,  -0.297579, 0.014879,
    -0.247982, -0.004960, -0.128951, 0.421570,  -0.059516, 0.362054,  -0.203346,
    -0.143830, -0.099193, -0.024798, 0.094233,  -0.123991, 0.163668,  0.109112,
    -0.104153, -0.233104, 0.009919,  -0.218225, 0.376933,  0.104153,  -0.059516,
    0.049596,  -0.054556, 0.019839,  -0.044637, -0.019839, 0.371974,  -0.019839,
    0.104153,  0.168628,  -0.024798, -0.272781, -0.158709, 0.223184,  0.044637,
    0.039677,  -0.168628, -0.287660, -0.109112, 0.094233,  -0.089274, -0.148789,
    0.178547,  -0.039677, -0.089274, -0.049596, -0.024798, 0.064475,  -0.158709,
    0.089274,  0.029758,  -0.247982, 0.362054,  0.024798,  -0.004960, -0.099193,
    0.173588,  -0.059516, 0.188467,  -0.629876, 0.094233,  0.371974,  0.069435,
    0.252942,  -0.357095, -0.272781, -0.367014, 0.014879,  -0.049596, -0.262861,
    0.009919,  -0.094233, -0.094233, 0.059516,  0.223184,  0.133911,  0.411651,
    -0.044637, -0.044637, 0.109112,  0.228144,  0.386853,  -0.233104, 0.069435,
    0.228144,  -0.302539, 0.029758,  0.089274,  0.044637,  -0.238063, -0.138870,
    -0.158709, -0.019839, 0.049596,  0.039677,  0.000000,  -0.069435, 0.109112,
    -0.213265, -0.188467, -0.262861, -0.267821, -0.094233, 0.133911,  0.391812,
    0.123991,  -0.317418, 0.233104,  -0.029758, -0.099193, -0.193426, 0.074395,
    -0.009919, 0.252942,  0.322377,  -0.530683, 0.208305,  0.252942,  0.203346,
    -0.069435, -0.262861};

// Time filter of shape {64, 8}.
const float time_weights_data_16x1x1[] = {
    -0.052026, 0.043107,  0.053512,  0.013378,  0.011892,  -0.182834, -0.108511,
    0.153105,  0.050539,  -0.173915, 0.145672,  0.208103,  -0.221481, 0.108511,
    -0.496475, 0.181347,  -0.016351, -0.132294, -0.234859, -0.243778, 0.028243,
    -0.228914, -0.130808, -0.167969, -0.041621, -0.306209, -0.193239, -0.028243,
    -0.057972, -0.057972, -0.497962, 0.054999,  0.181347,  0.047566,  -0.099592,
    -0.111484, -0.130808, -0.071350, 0.380532,  0.010405,  0.041621,  0.052026,
    0.022297,  0.081755,  0.098106,  0.099592,  -0.584176, -0.023783, 0.062431,
    -0.090674, -0.279453, -0.486070, -0.273507, 0.004459,  -0.062431, 0.095133,
    0.056485,  0.022297,  -0.105538, -0.184320, 0.358235,  0.254183,  0.049053,
    0.084728,  0.218508,  0.078782,  -0.136754, -0.017837, -0.124862, -0.118916,
    -0.001486, 0.043107,  0.254183,  0.087701,  0.261616,  0.309182,  -0.404315,
    -0.040134, -0.046080, -0.052026, -0.034188, -0.475665, -0.025270, -0.049053,
    -0.046080, -0.062431, 0.020810,  0.040134,  -0.135267, -0.169456, -0.050539,
    -0.576743, 0.034188,  0.075809,  0.101079,  0.136754,  0.083241,  0.077296,
    -0.050539, 0.761064,  -0.335938, -0.080268, 0.025270,  0.257156,  0.227427,
    0.252697,  0.065404,  0.115943,  0.222968,  -0.026756, -0.054999, 0.107025,
    -0.093646, 0.041621,  -0.092160, -0.474178, -0.016351, 0.004459,  0.049053,
    0.019324,  0.019324,  0.074323,  0.038648,  -0.613905, 0.182834,  0.075809,
    0.028243,  0.019324,  0.010405,  -0.011892, 0.001486,  -0.492016, -0.224454,
    -0.474178, -0.147159, 0.002973,  0.102565,  0.136754,  -0.267561, -0.001486,
    -0.095133, -0.040134, 0.066890,  0.074323,  0.104052,  0.532150,  0.090674,
    0.072836,  -0.053512, -0.004459, 0.020810,  0.046080,  0.062431,  0.477151,
    0.133781,  -0.029729, -0.026756, 0.031215,  0.156077,  0.096619,  0.251210,
    0.352289,  0.657012,  0.047566,  -0.014865, -0.072836, -0.016351, 0.008919,
    -0.053512, 0.016351,  0.300263,  0.047566,  0.020810,  0.169456,  0.001486,
    0.007432,  0.111484,  0.044594,  -0.188779, -0.096619, 0.074323,  -0.040134,
    0.160537,  0.138240,  0.184320,  0.377559,  -0.092160, -0.049053, 0.056485,
    -0.032702, 0.001486,  -0.083241, -0.472692, -0.114457, -0.117430, -0.075809,
    0.026756,  0.163510,  0.172428,  0.127835,  -0.199185, -0.218508, -0.057972,
    -0.132294, -0.162023, -0.019324, -0.245265, -0.395396, -0.254183, 0.084728,
    0.248238,  0.191752,  0.221481,  0.173915,  0.173915,  -0.208103, -0.077296,
    0.384991,  -0.313641, -0.313641, -0.147159, -0.090674, 0.035675,  0.059458,
    -0.010405, 0.019324,  0.087701,  0.016351,  0.037161,  0.469719,  -0.074323,
    0.092160,  0.026756,  0.090674,  0.098106,  0.004459,  -0.034188, 0.492016,
    -0.367154, -0.093646, -0.063917, 0.041621,  0.017837,  0.026756,  -0.062431,
    -0.350803, 0.425125,  0.002973,  0.083241,  0.075809,  0.016351,  0.047566,
    -0.185807, -0.107025, -0.098106, -0.144186, 0.255670,  0.020810,  0.105538,
    0.029729,  0.129321,  0.156077,  0.141213,  0.334452,  0.147159,  -0.066890,
    0.035675,  0.115943,  0.240805,  0.328506,  0.162023,  -0.237832, 0.218508,
    0.233373,  0.214049,  0.099592,  0.026756,  -0.322560, -0.236346, -0.166483,
    0.225941,  0.109997,  -0.147159, 0.147159,  -0.266075, 0.111484,  0.078782,
    -0.120403, 0.022297,  -0.075809, -0.148645, -0.251210, -0.176888, -0.044594,
    -0.023783, 0.016351,  0.026756,  -0.013378, -0.069863, -0.112970, 0.013378,
    0.086214,  0.014865,  0.352289,  -0.240805, -0.135267, -0.114457, -0.472692,
    0.334452,  0.095133,  0.047566,  0.130808,  -0.068377, -0.007432, -0.130808,
    -0.121889, -0.053512, -0.245265, -0.371613, -0.083241, 0.000000,  -0.028243,
    0.029729,  -0.093646, -0.004459, -0.038648, -0.108511, -0.475665, -0.169456,
    -0.047566, -0.010405, -0.114457, -0.353776, -0.034188, -0.044594, 0.041621,
    -0.047566, -0.107025, 0.004459,  0.053512,  0.047566,  -0.358235, -0.193239,
    0.040134,  -0.096619, -0.054999, 0.099592,  0.032702,  0.205130,  -0.170942,
    -0.237832, -0.405801, -0.126348, -0.072836, -0.203644, -0.169456, -0.093646,
    -0.074323, 0.078782,  0.607959,  -0.437017, -0.164996, -0.166483, 0.043107,
    -0.016351, 0.258643,  0.065404,  -0.057972, 0.017837,  0.080268,  0.050539,
    -0.013378, -0.215536, -0.524718, 0.260129,  0.040134,  -0.002973, -0.046080,
    0.020810,  0.025270,  0.145672,  0.515799,  0.233373,  0.011892,  0.139727,
    0.126348,  0.065404,  -0.007432, -0.008919, 0.035675,  0.083241,  0.040134,
    -0.005946, 0.503907,  -0.490529, -0.181347, -0.092160, -0.038648, 0.019324,
    0.133781,  -0.011892, 0.041621,  0.062431,  -0.062431, -0.040134, -0.092160,
    -0.111484, -0.133781, -0.130808, -0.484583, -0.248238, 0.037161,  -0.092160,
    -0.056485, -0.041621, 0.112970,  0.248238,  0.438503,  0.258643,  -0.013378,
    0.004459,  0.043107,  0.040134,  0.017837,  0.101079,  0.264589,  0.212563,
    0.014865,  0.285399,  0.153105,  0.170942,  0.358235,  0.334452,  0.086214,
    0.132294,  0.098106,  -0.001486, 0.107025,  0.200671,  -0.026756, 0.344857,
    0.227427,  -0.041621, 0.098106,  0.063917,  -0.093646, 0.130808,  0.285399,
    -0.319587, 0.035675,  -0.017837, -0.319587, 0.016351,  -0.098106, -0.017837,
    0.083241,  0.074323,  -0.054999, 0.276480,  0.316614,  -0.099592, -0.059458,
    0.156077,  -0.043107, 0.035675,  0.056485,  -0.022297, 0.017837,  -0.001486,
    0.340398,  0.492016,  0.004459,  0.057972,  -0.150132, -0.206617, -0.257156,
    -0.248238, -0.080268, -0.164996, 0.352289,  -0.054999, -0.056485, 0.010405,
    -0.049053, -0.041621, -0.099592, 0.013378,  -0.089187, 0.057972,  -0.413234,
    0.217022,  0.013378,  -0.080268, -0.035675, 0.035675,  0.007432,  0.002973,
    -0.469719, 0.141213,  0.136754,  0.153105,  0.130808,  -0.104052, -0.508367,
    -0.291345, -0.072836, -0.019324, -0.252697, -0.214049, -0.214049, 0.130808,
    0.484583};

// Bias of shape {64} for the 16x1x1 test case (64 values).
const float bias_data_16x1x1[] = {
    -0.245395, -0.083545, -0.262522, -0.407912, -0.560898, -0.364789, -0.037964,
    -0.378594, 0.178152,  0.400380,  -0.301349, -0.240913, -0.159454, -0.158757,
    -0.073665, 0.455906,  -0.061232, 0.318907,  -0.226993, -0.344644, 0.140316,
    0.559608,  0.109774,  0.437391,  0.113849,  -0.162068, 0.039572,  0.569472,
    0.460205,  0.113459,  0.370469,  0.176811,  0.203063,  -0.296975, -0.271655,
    0.059862,  -0.159912, -0.077310, -0.338314, -0.195477, -0.256762, 0.233834,
    0.083172,  0.029040,  -0.236288, -0.267054, -0.166627, 0.188319,  -0.271391,
    -0.222920, 0.106463,  0.263614,  0.384986,  -0.125957, -0.095890, 0.363686,
    -0.036990, -0.358884, -0.178254, 0.305596,  0.390088,  -0.189437, 0.613409,
    0.399639};

// Activation state with shape {64, 8}. These initial values must be copied into
// a mutable activation state tensor.
const float initial_activation_state_data_16x1x1[] = {
    -0.582275, -0.586623, -1.262373, -1.277279, -1.542175, -1.271999, -1.429757,
    -1.184425, -0.462094, -1.443421, 0.230736,  -0.494701, -0.354955, -2.534061,
    -4.277471, -4.218467, 0.403711,  -0.248748, -0.330111, -0.467683, 0.549047,
    0.733511,  -0.230115, 0.793136,  -1.126353, -0.984123, -0.081984, -0.222351,
    0.692830,  0.517060,  1.367958,  2.118860,  -0.116766, -0.826365, -2.402700,
    -2.313884, -2.898954, -2.076005, -2.405185, -2.755481, 0.329490,  0.085400,
    -1.485966, -2.034702, -2.161405, -1.269515, -1.151818, -1.823841, 0.561469,
    1.109273,  1.693411,  -0.082605, -0.069252, -1.225107, -1.330693, -1.411435,
    0.253406,  -0.357439, -1.593415, -0.879779, -1.111136, 1.821357,  2.471952,
    1.236908,  -4.014127, -2.810448, -2.944604, -1.930980, -1.566398, -0.838166,
    -0.319242, 0.749349,  1.156476,  0.658670,  1.997437,  2.080663,  2.912618,
    2.677224,  2.642442,  2.796163,  -0.272349, -0.473273, 3.120063,  2.747097,
    3.595510,  1.874150,  2.049919,  2.093396,  -1.049959, 0.277939,  -1.255541,
    -1.052443, -1.810177, -0.883505, -0.538178, 0.524203,  -1.017662, -0.269244,
    0.039129,  -0.227941, -0.114592, -2.018243, -2.548968, -0.706804, 0.890959,
    0.102480,  0.349986,  0.405885,  1.287216,  0.756181,  0.319242,  -0.641590,
    -3.841774, -2.716042, -4.342065, -3.826557, -2.924729, -1.643724, -1.237839,
    -0.597492, -1.954892, -1.215169, -1.528201, -1.018904, -0.863941, -0.293467,
    0.039439,  0.672023,  1.408019,  1.362679,  1.467644,  1.006171,  0.310236,
    -0.249990, -1.048406, -0.752144, -1.831605, -1.058033, -1.096541, -0.293467,
    0.051551,  0.232600,  0.088816,  2.570395,  0.704009,  2.465120,  3.010751,
    2.139357,  0.630410,  1.006171,  1.545281,  1.486898,  -1.162998, -2.344317,
    -4.593918, -3.522842, -2.872247, -1.416714, -0.642521, -0.230115, 0.315205,
    -0.368930, -0.162726, 0.396879,  0.505570,  0.534451,  0.554947,  1.270447,
    0.388805,  0.531967,  -1.243119, -0.671713, -1.214859, -0.238189, 0.016459,
    -1.164550, 0.609603,  3.293348,  2.600208,  1.454290,  -1.034121, -1.760179,
    -1.192500, -0.613951, 3.449553,  2.912618,  1.917937,  1.435968,  0.879158,
    1.118279,  0.102791,  -0.502465, -0.239121, -0.092853, 1.786265,  1.943091,
    2.547104,  2.630641,  2.585302,  2.965411,  -0.945615, -2.538720, -2.474126,
    -1.088156, 0.056209,  0.864873,  0.170490,  0.457435,  0.545941,  0.752765,
    1.569503,  1.129459,  0.662086,  -0.527929, -0.810838, -1.662978, 1.285042,
    1.653040,  4.130893,  2.961995,  4.147041,  3.256393,  3.881524,  2.522571,
    -0.875431, -1.112378, 2.105817,  2.180970,  3.121926,  1.577577,  1.639376,
    2.906407,  -0.142230, 0.421101,  2.212335,  2.311399,  3.993321,  3.651719,
    4.206666,  4.678387,  -1.304917, -1.130701, -2.543067, -2.500212, -2.197118,
    -1.197158, -0.949652, -0.282908, 0.320795,  -1.543728, 1.290322,  1.788128,
    3.957297,  3.205774,  2.892432,  2.297114,  0.138814,  -0.139435, 0.936920,
    0.344707,  0.723263,  -1.772290, -3.138385, -2.287177, -2.405806, -1.859864,
    -4.572801, -3.410424, -3.855748, -2.239663, -2.269786, -1.582857, 4.238342,
    3.858543,  2.499901,  1.087535,  0.290051,  -0.026086, -0.880400, -2.602692,
    -1.404292, 0.253096,  -0.665502, -1.443421, -0.925119, -0.096580, 1.115484,
    1.846200,  -1.604284, -1.244671, -0.464888, 0.326385,  0.168006,  -0.262723,
    -0.744691, 0.953379,  -0.407127, -0.349986, -1.154302, 0.831023,  1.590931,
    2.538720,  2.063583,  3.697680,  -0.752455, -1.293117, -1.330693, -1.869802,
    -0.592523, 0.631652,  1.198089,  -0.481347, 3.738983,  4.153252,  2.782499,
    2.244321,  0.709289,  1.650245,  1.700865,  0.385078,  2.192460,  2.610456,
    4.009780,  3.492719,  2.574743,  2.116687,  1.856138,  1.205853,  2.722563,
    4.075305,  5.415935,  3.009198,  2.715421,  1.571056,  0.897170,  -2.430339,
    0.749970,  0.425760,  -0.302783, 0.817359,  1.031636,  1.913589,  2.686229,
    1.631923,  -1.459259, -1.793097, -1.187531, -1.553355, -0.844998, -1.296843,
    -1.805519, -0.486627, 0.909591,  2.082837,  -1.473855, -2.456735, -3.851401,
    -2.760139, -3.060438, -2.605487, -2.138735, -2.441519, -1.333177, -1.353984,
    -0.245642, -0.588486, 0.033850,  2.084700,  0.076084,  0.690035,  0.747797,
    0.594697,  -1.016109, -1.348083, -1.201195, -1.088466, 2.045571,  2.460772,
    0.717984,  0.041613,  -0.721711, 1.134738,  2.322269,  1.112378,  -0.307441,
    -0.581033, -0.868599, -0.018633, 0.856488,  0.919839,  0.303094,  -0.433213,
    0.811148,  -0.508986, -1.060828, -1.227591, -1.566087, -1.117968, -1.385038,
    -2.011101, -0.490353, -1.849616, -0.594697, -1.055859, 1.110205,  0.622646,
    0.145957,  0.359303,  1.012072,  0.774814,  -0.400295, -1.484103, -2.007374,
    -1.441247, -0.997787, -0.581033, -0.545941, -0.306510, 0.693451,  0.087264,
    -0.227320, -1.211753, -1.532859, -1.688753, 0.065215,  0.134777,  0.608051,
    -0.393152, -0.214588, -0.635689, -1.499320, 0.069562,  -1.555839, -2.633126,
    -2.966032, -1.550870, -0.101549, 0.874189,  0.436318,  0.299367,  2.289972,
    2.339659,  2.602071,  1.564535,  0.019254,  -0.583207, -1.295912, -2.424749,
    -1.221070, -1.175109, -0.577306, -0.102791, 1.877876,  2.568222,  2.173827,
    3.131243,  2.637784,  2.088737,  3.679047,  3.218506,  2.483442,  1.650556,
    1.363611,  -0.027328, 1.486898,  -0.721711, -3.684327, -3.006093, -3.777491,
    -2.327548, -2.737470, -4.549510, -0.060867, 0.127635,  0.680408,  0.581344,
    0.320174,  -0.403090, -0.838166, 0.293777,  -0.995613, -0.165521, -0.419859,
    1.110515,  1.203679,  1.749931,  2.467294,  4.276539,  0.031055,  -0.967664,
    1.167035,  1.865144,  3.221923,  3.248630,  4.121266,  4.187723,  0.749039,
    -1.571056, 0.785994,  1.568572,  3.759479,  3.588678,  4.116608,  3.864444,
    -0.290051, -0.271107, 0.375140,  0.537556,  0.536314,  0.095959,  0.054656,
    0.088816};

// Expected output for the 16x1x1 test case: one output with shape {1, 64}.
// Negative values are kept as-is here; presumably this is the golden for the
// run without a fused activation (TODO confirm against the test that uses it).
const float golden_output_16x1x1[] = {
    -0.087914, 1.145864,  -0.418088, -1.556392, -0.925298, 0.205252,  0.289119,
    1.331180,  -0.218010, 0.963057,  -2.225886, 1.248478,  1.448983,  0.355467,
    1.682174,  0.803739,  0.449738,  0.543566,  1.916269,  -2.975136, 0.222774,
    0.241589,  -0.104216, 1.561748,  0.936818,  -0.089907, -0.520117, -0.870353,
    1.606074,  0.895770,  0.521297,  -0.369994, -0.889351, -2.809309, 2.404628,
    1.069754,  -0.195456, -1.105652, 1.272715,  -1.233177, 1.271416,  -1.691805,
    -1.058125, -0.716227, 0.052540,  1.262483,  0.540555,  1.735760,  -0.539197,
    -0.014367, -0.243002, 1.072254,  0.528985,  -0.731151, -1.262649, 2.338702,
    -0.603093, 0.970736,  -3.567897, 0.035085,  -0.201711, -0.550400, 1.545573,
    -1.805005};

// Expected output with shape {1, 64} for the ReLU variant of the 16x1x1 test
// case. Element-wise identical to golden_output_16x1x1 except that every
// negative value is replaced by 0.
const float golden_output_relu_16x1x1[] = {
    0.000000, 1.145864, 0.000000, 0.000000, 0.000000, 0.205252, 0.289119,
    1.331180, 0.000000, 0.963057, 0.000000, 1.248478, 1.448983, 0.355467,
    1.682174, 0.803739, 0.449738, 0.543566, 1.916269, 0.000000, 0.222774,
    0.241589, 0.000000, 1.561748, 0.936818, 0.000000, 0.000000, 0.000000,
    1.606074, 0.895770, 0.521297, 0.000000, 0.000000, 0.000000, 2.404628,
    1.069754, 0.000000, 0.000000, 1.272715, 0.000000, 1.271416, 0.000000,
    0.000000, 0.000000, 0.052540, 1.262483, 0.540555, 1.735760, 0.000000,
    0.000000, 0.000000, 1.072254, 0.528985, 0.000000, 0.000000, 2.338702,
    0.000000, 0.970736, 0.000000, 0.035085, 0.000000, 0.000000, 1.545573,
    0.000000};

template <typename T>
void ValidateSVDFGoldens(const int batch_size, const int num_units,
                         const int input_size, const int rank,
                         TfLiteTensor* tensors, const int tensor_count,
                         TfLiteFusedActivation activaiton,
                         const T* input_sequences_data,
                         const int input_sequences_len, T* output_data,
                         const T* expected_output, float tolerance = 1e-5f) {
  TfLiteSVDFParams params;
  params.rank = rank;
  params.activation = activaiton;

  int inputs_array_data[] = {5, 0, 1, 2, 3, 4};
  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);

  int outputs_array_data[] = {1, 5};
  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);

  const TfLiteRegistration registration = Register_SVDF();
  micro::KernelRunner runner(registration, tensors, tensor_count, inputs_array,
                             outputs_array, &params);

  TfLiteStatus init_and_prepare_status = runner.InitAndPrepare();
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, init_and_prepare_status);

  // Abort early to make it clear init and prepare failed.
  if (init_and_prepare_status != kTfLiteOk) {
    return;
  }

  int num_inputs = input_sequences_len / (input_size * batch_size);

  for (int i = 0; i < num_inputs; ++i) {
    const T* input_batch_start =
        input_sequences_data + i * input_size * batch_size;

    memcpy(tensors[0].data.raw, input_batch_start, tensors[0].bytes);
    TfLiteStatus status = runner.Invoke();
    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, status);

    // Only validate outputs when invoke has succeeded.
    if (status == kTfLiteOk) {
      int output_idx = 0;
      int golden_idx = i * batch_size * num_units;
      for (int j = golden_idx; j < golden_idx + batch_size * num_units; ++j) {
        TF_LITE_MICRO_EXPECT_NEAR(expected_output[j], output_data[output_idx],
                                  tolerance);
        output_idx++;
      }
    }
  }
}

#if !defined(XTENSA)  // Needed to avoid build errors from unused functions.
// Float SVDF test helper: wraps the supplied buffers in the six tensors of an
// SVDF node (input, feature weights, time weights, bias, activation state,
// output) and checks the kernel against `expected_output` through
// ValidateSVDFGoldens. `scratch_data` is accepted but not referenced here.
void TestSVDF(const int batch_size, const int num_units, const int input_size,
              const int memory_size, const int rank,
              TfLiteFusedActivation activation, float* input_data,
              const float* feature_weights_data, const float* time_weights_data,
              float* activation_state_data, const float* bias_data,
              float* scratch_data, float* output_data,
              const float* input_sequences_data, int input_sequences_len,
              const float* expected_output, float tolerance = 1e-5f) {
  const int filter_count = num_units * rank;

  // Shape descriptors: number of dimensions first, then the dimensions.
  int input_shape[] = {2, batch_size, input_size};
  int feature_weights_shape[] = {2, filter_count, input_size};
  int time_weights_shape[] = {2, filter_count, memory_size};
  int activation_state_shape[] = {2, batch_size, memory_size * filter_count};
  int bias_shape[] = {1, num_units};
  int output_shape[] = {2, batch_size, num_units};

  TfLiteIntArray* input_dims = IntArrayFromInts(input_shape);
  TfLiteIntArray* feature_weights_dims =
      IntArrayFromInts(feature_weights_shape);
  TfLiteIntArray* time_weights_dims = IntArrayFromInts(time_weights_shape);
  TfLiteIntArray* activation_state_dims =
      IntArrayFromInts(activation_state_shape);
  TfLiteIntArray* bias_dims = IntArrayFromInts(bias_shape);
  TfLiteIntArray* output_dims = IntArrayFromInts(output_shape);

  // 5 inputs followed by 1 output; the activation state is the only variable
  // tensor.
  TfLiteTensor tensors[] = {
      CreateTensor(input_data, input_dims),
      CreateTensor(feature_weights_data, feature_weights_dims),
      CreateTensor(time_weights_data, time_weights_dims),
      CreateTensor(bias_data, bias_dims),
      CreateTensor(activation_state_data, activation_state_dims,
                   /*is_variable=*/true),
      CreateTensor(output_data, output_dims),
  };
  const int tensor_count =
      static_cast<int>(sizeof(tensors) / sizeof(tensors[0]));

  ValidateSVDFGoldens(batch_size, num_units, input_size, rank, tensors,
                      tensor_count, activation, input_sequences_data,
                      input_sequences_len, output_data, expected_output,
                      tolerance);
}
#endif

// The pattern to this method's arguemnts is:
// <kernel metadata>
// for each tensor in
//     {input, feature weights, time weights, bias, activation state, output}:
//   <tensor float values> <tensor quantized buffer> <tensor quantization data>
inline void TestIntegerSVDF(
    const int batch_size, const int num_units, const int input_size,
    const int memory_size, const int rank, TfLiteFusedActivation activation,
    int8_t* input_quantized, float input_scale, int input_zero_point,
    const float* feature_weights_data, int8_t* feature_weights_quantized,
    const float feature_weights_scale, const float* time_weights_data,
    int16_t* time_weights_quantized, float time_weights_scale,
    const float* bias_data, int32_t* bias_quantized,
    const float* initial_activation_state_data,
    int16_t* activation_state_quantized, float activation_state_scale,
    int8_t* output_data, float output_scale, int output_zero_point,
    const float* input_sequences_data, int8_t* input_sequences_quantized,
    const int input_sequences_len, const float* golden_output,
    int8_t* golden_output_quantized, int golden_output_len) {
  const int num_filters = num_units * rank;

  int input_dims_arg[] = {2, batch_size, input_size};
  TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_arg);

  int feature_weights_dims_args[] = {2, num_filters, input_size};
  TfLiteIntArray* feature_weights_dims =
      IntArrayFromInts(feature_weights_dims_args);

  int time_weights_dims_args[] = {2, num_filters, memory_size};
  TfLiteIntArray* time_weights_dims = IntArrayFromInts(time_weights_dims_args);

  int bias_dims_data[] = {1, num_units};
  TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_data);

  int activation_state_dims_args[] = {2, batch_size, memory_size * num_filters};
  TfLiteIntArray* activation_state_dims =
      IntArrayFromInts(activation_state_dims_args);

  int output_dims_args[] = {2, batch_size, num_units};
  TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_args);

  const int tensor_count = 6;  // 5 inputs, 1 output

  TfLiteTensor tensors[] = {
      CreateQuantizedTensor(input_quantized, input_dims, input_scale,
                            input_zero_point),
      CreateQuantizedTensor(feature_weights_data, feature_weights_quantized,
                            feature_weights_dims, feature_weights_scale, 0),
      CreateQuantizedTensor(time_weights_data, time_weights_quantized,
                            time_weights_dims, time_weights_scale, 0),
      CreateQuantizedBiasTensor(bias_data, bias_quantized, bias_dims,
                                time_weights_scale, activation_state_scale),
      CreateQuantizedTensor(initial_activation_state_data,
                            activation_state_quantized, activation_state_dims,
                            activation_state_scale, 0,
                            /*is_variable=*/true),
      CreateQuantizedTensor(output_data, output_dims, output_scale,
                            output_zero_point)};

  tflite::Quantize(golden_output, golden_output_quantized, golden_output_len,
                   output_scale, output_zero_point);
  tflite::Quantize(input_sequences_data, input_sequences_quantized,
                   input_sequences_len, input_scale, input_zero_point);

  ValidateSVDFGoldens(batch_size, num_units, input_size, rank, tensors,
                      tensor_count, activation, input_sequences_quantized,
                      input_sequences_len, output_data, golden_output_quantized,
                      /*tolerance*/ 1);
}

}  // namespace
}  // namespace testing
}  // namespace tflite

TF_LITE_MICRO_TESTS_BEGIN

#if !defined(XTENSA)  // TODO(b/170332589): xtensa kernels are less general than
                      // reference kernels and we ifdef out test cases that are
                      // currently known to fail.
// Float SVDF on the 2x2x10 data set (batch of 2, 4 units, rank 2), no fused
// activation, default tolerance.
TF_LITE_MICRO_TEST(SvdfFloat2x2Input2x4OutputShouldMatchGolden) {
  namespace svdf_test = tflite::testing;

  constexpr int kBatchSize = 2;
  constexpr int kNumUnits = 4;
  constexpr int kInputSize = 2;
  constexpr int kMemorySize = 10;
  constexpr int kRank = 2;
  constexpr int kNumFilters = kNumUnits * kRank;

  // Working buffers sized from the kernel geometry above.
  float input_buffer[kBatchSize * kInputSize];
  float state_buffer[kBatchSize * kMemorySize * kNumFilters];
  float scratch_buffer[kBatchSize * kNumFilters];
  float output_buffer[kBatchSize * kNumUnits];

  // Seed the mutable activation state from the read-only initial values.
  memcpy(state_buffer, svdf_test::initial_activation_state_data_2x2x10,
         sizeof(svdf_test::initial_activation_state_data_2x2x10));

  svdf_test::TestSVDF(
      kBatchSize, kNumUnits, kInputSize, kMemorySize, kRank, kTfLiteActNone,
      input_buffer, svdf_test::feature_weights_data_2x2x10,
      svdf_test::time_weights_data_2x2x10, state_buffer,
      svdf_test::bias_data_2x2x10, scratch_buffer, output_buffer,
      svdf_test::input_data_2x2x10,
      sizeof(svdf_test::input_data_2x2x10) / sizeof(float),
      svdf_test::golden_output_2x2x10);
}
#endif

// Quantized SVDF on the 2x2x10 data set (batch of 2, 4 units, rank 2). The
// float reference data is quantized inside TestIntegerSVDF using the scales
// and zero points chosen here.
TF_LITE_MICRO_TEST(SvdfQuantized2x2Input2x4OutputShouldMatchGolden) {
  namespace svdf_test = tflite::testing;

  constexpr int kBatchSize = 2;
  constexpr int kNumUnits = 4;
  constexpr int kInputSize = 2;
  constexpr int kMemorySize = 10;
  constexpr int kRank = 2;
  constexpr int kNumFilters = kNumUnits * kRank;

  // Element counts of the float reference arrays.
  constexpr int kInputSequenceLen =
      sizeof(svdf_test::input_data_2x2x10) / sizeof(float);
  constexpr int kFeatureWeightsLen =
      sizeof(svdf_test::feature_weights_data_2x2x10) / sizeof(float);
  constexpr int kTimeWeightsLen =
      sizeof(svdf_test::time_weights_data_2x2x10) / sizeof(float);
  constexpr int kBiasLen = sizeof(svdf_test::bias_data_2x2x10) / sizeof(float);
  constexpr int kGoldenLen =
      sizeof(svdf_test::golden_output_2x2x10) / sizeof(float);

  // Scales are chosen so the full integer range covers the stated interval.
  const float input_scale = 2.5f / INT8_MAX;              // Range is [-2.5, 2.5]
  const float feature_weights_scale = 1.f / INT8_MAX;     // Range is [-1, 1]
  const float time_weights_scale = 1.f / INT16_MAX;       // Range is [-1, 1]
  const float activation_state_scale = 16.f / INT16_MAX;  // Range is [-16, 16]
  const float output_scale = 1.f / INT8_MAX;              // Range is [-1, 1]

  const int input_zero_point = 0;
  const int output_zero_point = 0;

  // Buffers that receive the quantized tensors and reference data.
  int8_t input_quantized[kBatchSize * kInputSize];
  int8_t input_sequences_quantized[kInputSequenceLen];
  int8_t feature_weights_quantized[kFeatureWeightsLen];
  int16_t time_weights_quantized[kTimeWeightsLen];
  int16_t activation_state_quantized[kBatchSize * kMemorySize * kNumFilters];
  int32_t bias_quantized[kBiasLen];
  int8_t output_data[kBatchSize * kNumUnits];
  int8_t golden_quantized[kGoldenLen];

  svdf_test::TestIntegerSVDF(
      kBatchSize, kNumUnits, kInputSize, kMemorySize, kRank, kTfLiteActRelu,
      input_quantized, input_scale, input_zero_point,
      svdf_test::feature_weights_data_2x2x10, feature_weights_quantized,
      feature_weights_scale, svdf_test::time_weights_data_2x2x10,
      time_weights_quantized, time_weights_scale, svdf_test::bias_data_2x2x10,
      bias_quantized, svdf_test::initial_activation_state_data_2x2x10,
      activation_state_quantized, activation_state_scale, output_data,
      output_scale, output_zero_point, svdf_test::input_data_2x2x10,
      input_sequences_quantized, kInputSequenceLen,
      svdf_test::golden_output_2x2x10, golden_quantized, kGoldenLen);
}

#if !defined(XTENSA)  // TODO(b/170332589): xtensa kernels are less general than
                      // reference kernels and we ifdef out test cases that are
                      // currently known to fail.
// Float SVDF on the 16x1x1 data set (batch of 1, 64 units, rank 1), no fused
// activation, default tolerance.
TF_LITE_MICRO_TEST(SvdfFloat1x16Input64x1OutputShouldMatchGolden) {
  namespace svdf_test = tflite::testing;

  constexpr int kBatchSize = 1;
  constexpr int kNumUnits = 64;
  constexpr int kInputSize = 16;
  constexpr int kMemorySize = 8;
  constexpr int kRank = 1;
  constexpr int kNumFilters = kNumUnits * kRank;

  // Working buffers sized from the kernel geometry above.
  float input_buffer[kBatchSize * kInputSize];
  float output_buffer[kBatchSize * kNumUnits];
  float scratch_buffer[kBatchSize * kNumFilters];
  float state_buffer[kBatchSize * kMemorySize * kNumFilters];

  // Initialize activation state to starting values.
  memcpy(state_buffer, svdf_test::initial_activation_state_data_16x1x1,
         sizeof(svdf_test::initial_activation_state_data_16x1x1));

  // The sequence length passed is kInputSize; with a batch of 1 that is one
  // input's worth of elements, so exactly one invocation is validated.
  svdf_test::TestSVDF(
      kBatchSize, kNumUnits, kInputSize, kMemorySize, kRank, kTfLiteActNone,
      input_buffer, svdf_test::feature_weights_data_16x1x1,
      svdf_test::time_weights_data_16x1x1, state_buffer,
      svdf_test::bias_data_16x1x1, scratch_buffer, output_buffer,
      svdf_test::input_data_16x1x1, kInputSize,
      svdf_test::golden_output_16x1x1);
}

// Float SVDF on the 16x1x1 data set (batch of 1, 64 units, rank 1) with a
// fused ReLU, checked against the ReLU goldens at the default tolerance.
TF_LITE_MICRO_TEST(SvdfFloat1x16Input64x1OutputReluShouldMatchGolden) {
  namespace svdf_test = tflite::testing;

  constexpr int kBatchSize = 1;
  constexpr int kNumUnits = 64;
  constexpr int kInputSize = 16;
  constexpr int kMemorySize = 8;
  constexpr int kRank = 1;
  constexpr int kNumFilters = kNumUnits * kRank;

  // Working buffers sized from the kernel geometry above.
  float input_buffer[kBatchSize * kInputSize];
  float output_buffer[kBatchSize * kNumUnits];
  float scratch_buffer[kBatchSize * kNumFilters];
  float state_buffer[kBatchSize * kMemorySize * kNumFilters];

  // Initialize activation state to starting values.
  memcpy(state_buffer, svdf_test::initial_activation_state_data_16x1x1,
         sizeof(svdf_test::initial_activation_state_data_16x1x1));

  // The sequence length passed is kInputSize; with a batch of 1 that is one
  // input's worth of elements, so exactly one invocation is validated.
  svdf_test::TestSVDF(
      kBatchSize, kNumUnits, kInputSize, kMemorySize, kRank, kTfLiteActRelu,
      input_buffer, svdf_test::feature_weights_data_16x1x1,
      svdf_test::time_weights_data_16x1x1, state_buffer,
      svdf_test::bias_data_16x1x1, scratch_buffer, output_buffer,
      svdf_test::input_data_16x1x1, kInputSize,
      svdf_test::golden_output_relu_16x1x1);
}
#endif

// Quantized SVDF on the 16x1x1 data set (batch of 1, 64 units, rank 1), no
// fused activation. The float reference data is quantized inside
// TestIntegerSVDF using the scales and zero points chosen here.
TF_LITE_MICRO_TEST(SvdfQuantized1x16Input64x1OutputShouldMatchGolden) {
  namespace svdf_test = tflite::testing;

  constexpr int kBatchSize = 1;
  constexpr int kNumUnits = 64;
  constexpr int kInputSize = 16;
  constexpr int kMemorySize = 8;
  constexpr int kRank = 1;
  constexpr int kNumFilters = kNumUnits * kRank;

  // Element counts of the float reference arrays.
  constexpr int kInputSequenceLen =
      sizeof(svdf_test::input_data_16x1x1) / sizeof(float);
  constexpr int kFeatureWeightsLen =
      sizeof(svdf_test::feature_weights_data_16x1x1) / sizeof(float);
  constexpr int kTimeWeightsLen =
      sizeof(svdf_test::time_weights_data_16x1x1) / sizeof(float);
  constexpr int kBiasLen = sizeof(svdf_test::bias_data_16x1x1) / sizeof(float);
  constexpr int kGoldenLen =
      sizeof(svdf_test::golden_output_16x1x1) / sizeof(float);

  // Quantization parameters for this data set.
  const float input_scale = 0.10075444;
  const float feature_weights_scale = 0.00649388;
  const float time_weights_scale = 0.001571355;
  const float activation_state_scale = 0.00045896982;
  const float output_scale = 0.051445257;

  const int input_zero_point = 2;
  const int output_zero_point = 0;

  // Buffers that receive the quantized tensors and reference data.
  int8_t input_quantized[kBatchSize * kInputSize];
  int8_t input_sequences_quantized[kInputSequenceLen];
  int8_t feature_weights_quantized[kFeatureWeightsLen];
  int16_t time_weights_quantized[kTimeWeightsLen];
  int16_t activation_state_quantized[kBatchSize * kMemorySize * kNumFilters];
  int32_t bias_quantized[kBiasLen];
  int8_t output_data[kBatchSize * kNumUnits];
  int8_t golden_quantized[kGoldenLen];

  svdf_test::TestIntegerSVDF(
      kBatchSize, kNumUnits, kInputSize, kMemorySize, kRank, kTfLiteActNone,
      input_quantized, input_scale, input_zero_point,
      svdf_test::feature_weights_data_16x1x1, feature_weights_quantized,
      feature_weights_scale, svdf_test::time_weights_data_16x1x1,
      time_weights_quantized, time_weights_scale, svdf_test::bias_data_16x1x1,
      bias_quantized, svdf_test::initial_activation_state_data_16x1x1,
      activation_state_quantized, activation_state_scale, output_data,
      output_scale, output_zero_point, svdf_test::input_data_16x1x1,
      input_sequences_quantized, kInputSequenceLen,
      svdf_test::golden_output_16x1x1, golden_quantized, kGoldenLen);
}

// Quantized SVDF on the 16x1x1 data set (batch of 1, 64 units, rank 1) with a
// fused ReLU, checked against the ReLU goldens. The output zero point is -128
// here, unlike the non-ReLU variant of this test.
TF_LITE_MICRO_TEST(SvdfQuantized1x16Input64x1OutputReluShouldMatchGolden) {
  namespace svdf_test = tflite::testing;

  constexpr int kBatchSize = 1;
  constexpr int kNumUnits = 64;
  constexpr int kInputSize = 16;
  constexpr int kMemorySize = 8;
  constexpr int kRank = 1;
  constexpr int kNumFilters = kNumUnits * kRank;

  // Element counts of the float reference arrays.
  constexpr int kInputSequenceLen =
      sizeof(svdf_test::input_data_16x1x1) / sizeof(float);
  constexpr int kFeatureWeightsLen =
      sizeof(svdf_test::feature_weights_data_16x1x1) / sizeof(float);
  constexpr int kTimeWeightsLen =
      sizeof(svdf_test::time_weights_data_16x1x1) / sizeof(float);
  constexpr int kBiasLen = sizeof(svdf_test::bias_data_16x1x1) / sizeof(float);
  constexpr int kGoldenLen =
      sizeof(svdf_test::golden_output_relu_16x1x1) / sizeof(float);

  // Quantization parameters for this data set.
  const float input_scale = 0.10075444;
  const float feature_weights_scale = 0.00649388;
  const float time_weights_scale = 0.001571355;
  const float activation_state_scale = 0.00045896982;
  const float output_scale = 0.051445257;

  const int input_zero_point = 2;
  const int output_zero_point = -128;

  // Buffers that receive the quantized tensors and reference data.
  int8_t input_quantized[kBatchSize * kInputSize];
  int8_t input_sequences_quantized[kInputSequenceLen];
  int8_t feature_weights_quantized[kFeatureWeightsLen];
  int16_t time_weights_quantized[kTimeWeightsLen];
  int16_t activation_state_quantized[kBatchSize * kMemorySize * kNumFilters];
  int32_t bias_quantized[kBiasLen];
  int8_t output_data[kBatchSize * kNumUnits];
  int8_t golden_quantized[kGoldenLen];

  svdf_test::TestIntegerSVDF(
      kBatchSize, kNumUnits, kInputSize, kMemorySize, kRank, kTfLiteActRelu,
      input_quantized, input_scale, input_zero_point,
      svdf_test::feature_weights_data_16x1x1, feature_weights_quantized,
      feature_weights_scale, svdf_test::time_weights_data_16x1x1,
      time_weights_quantized, time_weights_scale, svdf_test::bias_data_16x1x1,
      bias_quantized, svdf_test::initial_activation_state_data_16x1x1,
      activation_state_quantized, activation_state_scale, output_data,
      output_scale, output_zero_point, svdf_test::input_data_16x1x1,
      input_sequences_quantized, kInputSequenceLen,
      svdf_test::golden_output_relu_16x1x1, golden_quantized, kGoldenLen);
}

TF_LITE_MICRO_TESTS_END
